#-*-coding:utf-8-*-
# date:2021-04-15
# author: likecy
# function : train

import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import  sys

from tensorboardX import SummaryWriter
from utils.model_utils import *
from utils.common_utils import *
from data_iter.datasets import *
from loss.loss import FocalLoss
from models.build_model import *
from sklearn.metrics import precision_score, recall_score, f1_score
import torch.nn.functional as F
import cv2
import time
import json
import os
os.environ['TZ'] = "Asia/Shanghai"  # switch the process timezone to Shanghai time

# Registry mapping the CLI `--model` string to the constructor imported from
# models.build_model.  trainer() looks the constructor up by ops.model; the
# three 'efficientnet-*' keys share one constructor which also receives the
# key itself as its `model_name` argument.
MODEL_NAMES ={
    'alexnet': Alexnet,
    'googlenet':Googlenet,
    'resnet18':Resnet18,
    'resnet50': Resnet50,
    'resnet101': Resnet101,
    'resnet152':Resnet152,
    'resnext101_32x8d': Resnext101_32x8d,
    'resnext101_32x16d': Resnext101_32x16d,
    'resnext101_32x48d': Resnext101_32x48d,
    'resnext101_32x32d': Resnext101_32x32d,
    'densenet121': Densenet121,
    'densenet169': Densenet169,
    'moblienetv2': Mobilenetv2,  # NOTE(review): key is misspelled ('moblienetv2'); kept for CLI compatibility
    'squeezenet1_0':Squeezenet1_0,
    'squeezenet1_1':Squeezenet1_1,
    'shufflenet_v2_x0_5':Shufflenet_v2_x0_5,
    'shufflenet_v2_x1_0':Shufflenet_v2_x1_0,
    'shufflenet_v2_x1_5':Shufflenet_v2_x1_5,
    'shufflenet_v2_x2_0':Shufflenet_v2_x2_0,
    'efficientnet-b7': Efficientnet,
    'efficientnet-b0': Efficientnet,
    'efficientnet-b8': Efficientnet
}





def tester(ops,epoch,model,criterion,
    test_split,test_split_label,
    use_cuda):
    """Evaluate `model` on the held-out test split and report metrics.

    Args:
        ops: argparse namespace (uses fix_res, img_size).
        epoch: current epoch index, used only for log output.
        model: network in eval mode (caller is responsible for model.eval()).
        criterion: loss module applied to the raw logits.
        test_split / test_split_label: parallel lists of image paths and
            integer class labels.
        use_cuda: move tensors to GPU when True.

    Returns:
        (mean test loss, precision, recall, f1).

    Fixes vs. original:
      * loss was computed from the softmaxed *numpy* array instead of the
        logits tensor (TypeError on every call);
      * the target tensor was only defined inside the `if use_cuda:` branch,
        so CPU runs raised NameError;
      * the target is now a shape-(1,) long tensor as nn.CrossEntropyLoss
        requires (FocalLoss here is assumed to follow the same contract —
        TODO confirm against loss/loss.py).
    """
    print('\n------------------------->>> tester traival loss')
    loss_test=[]
    y_true = []
    y_pred = []
    with torch.no_grad():
        for i in range(len(test_split)):
            file = test_split[i]
            label = test_split_label[i]

            img = cv2.imread(file)
            # input preprocessing: either aspect-preserving letterbox or plain resize
            if ops.fix_res:
                img_ = letterbox(img,size_=ops.img_size[0],mean_rgb = (128,128,128))
            else:
                img_ = cv2.resize(img, (ops.img_size[1],ops.img_size[0]), interpolation = cv2.INTER_CUBIC)

            # normalize to roughly [-0.5, 0.5] and lay out as (1, C, H, W)
            img_ = img_.astype(np.float32)
            img_ = (img_-128.)/256.
            img_ = img_.transpose(2, 0, 1)
            img_ = torch.from_numpy(img_)
            img_ = img_.unsqueeze_(0)

            # class-index target, shape (1,), long dtype (CrossEntropyLoss contract)
            label_ = torch.from_numpy(np.array([label])).long()

            if use_cuda:
                img_ = img_.cuda()
                label_ = label_.cuda()

            pre_ = model(img_.float())  # logits, shape (1, num_classes)

            outputs = F.softmax(pre_, dim=1)
            outputs = outputs[0]

            output = outputs.cpu().detach().numpy()
            output = np.array(output)

            max_index = np.argmax(output)
            score_ = output[max_index]
            y_true.append(label)
            y_pred.append(max_index)
            print('true{}   pre {}   --->>>  confidence {}'.format(label,max_index, score_))

            # fix: loss from the logits tensor, not the numpy softmax output
            loss = criterion(pre_, label_)
            loss_test.append(loss.item())

    # NOTE(review): sklearn's default average='binary' is only valid for
    # 2-class tasks; pass average='macro' if num_classes > 2 — confirm.
    p = precision_score(y_true, y_pred)
    r = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    print("*"*20+"第："+str(epoch)+"个epoch的测试集结果 start"+"*"*20)
    print("精确度：", p)
    print("召回率:", r)
    print("F1值", f1)
    print('loss : {}'.format(np.mean(loss_test)))
    print("*" * 20 + "测试集结果 end" + "*" * 20)
    return np.mean(loss_test),p,r,f1


def trainer(ops,f_log):
    """Run the full training loop for the model selected by ops.model.

    Args:
        ops: parsed argparse namespace (see the __main__ section).
        f_log: Logger object that tees stdout to a log file, or None.

    Side effects: writes checkpoints under ops.model_exp, TensorBoard logs
    under /content/drive/log/, and redirects sys.stdout when ops.log_flag.

    Fixes vs. original:
      * `val_split` was an undefined name when building the dataset — a
        NameError on every run, silently swallowed by the broad except;
        ops.val_factor (the CLI option for the validation ratio) is passed
        instead — TODO confirm against LoadImagesAndLabels' signature.
      * the SummaryWriter is now closed when training finishes;
      * unused locals (`idx`, `epochs_loss_dict`) removed.
    """
    try:
        os.environ['CUDA_VISIBLE_DEVICES'] = ops.GPUS

        if ops.log_flag:
            sys.stdout = f_log

        set_seed(ops.seed)

        test_split, test_split_label  = split_test_datasets(ops)
        num_classes = ops.num_classes
        print('num_classes : ',num_classes)
        #---------------------------------------------------------------- build the model
        print('use model : %s'%(ops.model))
        use_cuda = torch.cuda.is_available()

        writer = SummaryWriter('/content/drive/log/'+ops.model+'_log')  # TensorBoard log
        print('/**********************************************/')
        print('***************** Training {} ***************** '.format(ops.model))

        if not ops.model.startswith('efficientnet'):
            model_ = MODEL_NAMES[ops.model](num_classes=ops.num_classes,pretrained=ops.pretrained)
            # when fine-tuning from pretrained weights, freeze the first 7 child modules
            if ops.pretrained:
                ct = 0
                for child in model_.children():
                    ct += 1
                    if ct < 8:
                        print(child)
                        for param in child.parameters():
                            param.requires_grad = False
        else:
            # EfficientNet constructor also needs the variant name itself
            model_ = MODEL_NAMES[ops.model](model_name=ops.model,num_classes=ops.num_classes,pretrained=ops.pretrained)
            # when fine-tuning, freeze (roughly) the first 700 parameter tensors
            if ops.pretrained:
                c = 0
                for name, p in model_.named_parameters():
                    c += 1
                    print(name)
                    if c >=700:
                        break
                    p.requires_grad = False

        device = torch.device("cuda:0" if use_cuda else "cpu")
        model_ = model_.to(device)

        # resume from a fine-tune checkpoint if the path exists
        if os.access(ops.fintune_model,os.F_OK):
            chkpt = torch.load(ops.fintune_model, map_location=device)
            model_.load_state_dict(chkpt)
            print('load fintune model : {}'.format(ops.fintune_model))

        # Dataset — fix: `val_split` was previously an undefined name here
        dataset = LoadImagesAndLabels(path = ops.train_path,img_size=ops.img_size,flag_agu=ops.flag_agu,fix_res = ops.fix_res,val_split = ops.val_factor,have_label_file = ops.have_label_file)
        print('len train datasets : %s'%(dataset.__len__()))
        # Dataloader
        dataloader = DataLoader(dataset,
                                batch_size=ops.batch_size,
                                num_workers=ops.num_workers,
                                shuffle=True,
                                pin_memory=False,
                                drop_last = True)

        # optimizer
        optimizer_SGD = optim.SGD(model_.parameters(), lr=ops.init_lr, momentum=0.9, weight_decay=ops.weight_decay)
        optimizer = optimizer_SGD
        print('use optimizer : optimizer_SGD')

        # loss function
        if 'focalLoss' == ops.loss_define:
            criterion = FocalLoss(num_class = num_classes)
            print('use loss : focalLoss')
        else:
            criterion = nn.CrossEntropyLoss()  # CrossEntropyLoss == softmax + negative log-likelihood

        step = 0

        # training-state variables
        best_loss = np.inf
        loss_mean = 0.          # running sum of batch losses this epoch
        loss_idx = 0.           # number of batches accumulated into loss_mean
        flag_change_lr_cnt = 0  # epochs without improvement before the LR is decayed
        init_lr = ops.init_lr   # current learning rate

        for epoch in range(0, ops.epochs):
            if ops.log_flag:
                sys.stdout = f_log
            print('\nepoch %d ------>>>'%epoch)
            model_.train()
            # LR schedule: decay by ops.lr_decay after 10 consecutive
            # epochs whose mean loss did not beat best_loss
            if loss_mean!=0.:
                if best_loss > (loss_mean/loss_idx):
                    flag_change_lr_cnt = 0
                    best_loss = (loss_mean/loss_idx)
                else:
                    flag_change_lr_cnt += 1
                    if flag_change_lr_cnt > 10:
                        init_lr = init_lr*ops.lr_decay
                        set_learning_rate(optimizer, init_lr)
                        flag_change_lr_cnt = 0

            loss_mean = 0.
            loss_idx = 0.

            for i, (imgs_, labels_) in enumerate(dataloader):

                if use_cuda:
                    imgs_ = imgs_.cuda()  # pytorch input layout: (batch, channel, height, width)
                    labels_ = labels_.cuda()

                output = model_(imgs_.float())

                loss = criterion(output, labels_)
                loss_mean += loss.item()
                loss_idx += 1.
                # log every 10 batches
                if i%10 == 0:
                    acc = get_acc(output, labels_)
                    train_f1_sorce = f1_sorce(labels_,output)
                    loc_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
                    print('  %s - %s - epoch [%s/%s] (%s/%s):'%(loc_time,ops.model,epoch,ops.epochs,i,int(dataset.__len__()/ops.batch_size)),\
                    'mean loss : %.6f, loss : %.6f'%(loss_mean/loss_idx,loss.item()),\
                    ' acc : %.4f'%acc,' f1_sorce : %.4f'%train_f1_sorce,' lr : %.5f'%init_lr,' bs :',ops.batch_size,\
                    ' img_size: %s x %s'%(ops.img_size[0],ops.img_size[1]),' best_loss: %.4f'%best_loss)
                    writer.add_scalars(ops.model+ '_Train_val_loss', {'train_loss': loss.item(),'acc':acc,'f1_sorce':train_f1_sorce}, epoch)

                # backprop and parameter update
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                step += 1

                # early stop once the loss target is reached; save a checkpoint.
                # NOTE(review): this break only ends the current epoch, not the
                # whole training run — confirm whether that is intended.
                if(loss.item()<= ops.end_loss_point):
                    torch.save(model_.state_dict(),
                               ops.model_exp + '{}-size-{}_epoch-{}-pre_breackd.pth'.format(ops.model, ops.img_size[0],
                                                                                        epoch))
                    break

                # periodically refresh the 'latest' checkpoint within an epoch
                if i%(1000) == 0 and i > 0:
                    torch.save(model_.state_dict(), ops.model_exp + 'latest.pth')

            # every test_interval epochs: save weights and evaluate on the test split
            if (epoch%ops.test_interval==0):
                torch.save(model_.state_dict(),
                           ops.model_exp + '{}-size-{}_epoch-{}.pth'.format(ops.model, ops.img_size[0], epoch))
                model_.eval()
                loss_test,test_acc,test_r,test_f1 = tester(ops,epoch,model_,criterion,
                        test_split,test_split_label,
                        use_cuda)
                writer.add_scalars(ops.model+ '_test_interval', {'test_loss': loss_test,'test_acc':test_acc,"test_return":test_r,"test_f1_sorce":test_f1}, epoch)

        # final checkpoint after the last epoch
        torch.save(model_.state_dict(), ops.model_exp + '{}-size-{}_epoch-{}-trained.pth'.format(ops.model,ops.img_size[0],epoch))
        writer.close()  # flush TensorBoard events

    except Exception as e:
        print('Exception : ',e)  # top-level boundary: log the error
        print('Exception  file : ', e.__traceback__.tb_frame.f_globals['__file__'])  # file where it was raised
        print('Exception  line : ', e.__traceback__.tb_lineno)  # line where it was raised
if __name__ == "__main__":

    def _str2bool(v):
        # argparse `type=bool` is a trap: bool('False') is True, so any
        # explicit value on the command line used to enable the flag.
        if isinstance(v, bool):
            return v
        return str(v).strip().lower() in ('1', 'true', 't', 'yes', 'y')

    def _parse_size(v):
        # `type=tuple` split the CLI string into single characters; parse
        # "H,W" (or "HxW", or a single "S" meaning SxS) into an int tuple.
        if isinstance(v, tuple):
            return v
        parts = [int(x) for x in str(v).replace('x', ',').split(',')]
        return (parts[0], parts[1]) if len(parts) > 1 else (parts[0], parts[0])

    parser = argparse.ArgumentParser(description='西南交通大学-1806-图像分类任务')
    parser.add_argument('--seed', type=int, default = 123,
        help = 'seed') # random seed
    parser.add_argument('--model_exp', type=str, default = '/content/drive/MyDrive/model_exp',
        help = 'model_exp') # model output folder
    parser.add_argument('--model', type=str, default = 'densenet121',
        help = 'model :参考models/build_model 中的模型 MODEL_NAMES') # model type

    '''
        注意以下4个参数与具体分类任务数据集，息息相关
    '''
    #---------------------------------------------------------------------------------
    parser.add_argument('--train_path', type=str, default = './ct-datasets/',
        help = 'train_path') # training set path
    parser.add_argument('--test_path', type=str, default='./ct-datasets/',
                        help='test_path')  # test set path
    parser.add_argument('--num_classes', type=int , default = 2,
        help = 'num_classes') # number of classes (gesture: 14, Stanford Dogs: 120)
    parser.add_argument('--end_loss_point', type=float, default=0.05,
                        help='end_loss_point')  # loss threshold that stops training early
    #---------------------------------------------------------------------------------
    # fix: all boolean flags use _str2bool instead of the broken type=bool
    parser.add_argument('--have_label_file', type=_str2bool, default = False,
        help = 'have_label_file') # whether a label file must be parsed (Stanford Dogs: True)
    parser.add_argument('--GPUS', type=str, default = '0',
        help = 'GPUS') # GPU selection
    parser.add_argument('--val_factor', type=float, default = 0.1,
        help = 'val_factor') # ratio of the training set held out for validation
    parser.add_argument('--test_interval', type=int, default = 5,
        help = 'test_interval') # epoch interval for test-set evaluation
    parser.add_argument('--pretrained', type=_str2bool, default = False,
        help = 'imageNet_Pretrain') # start from ImageNet-pretrained weights
    parser.add_argument('--fintune_model', type=str, default = ' ',
        help = 'fintune_model') # fintune model
    parser.add_argument('--loss_define', type=str, default = 'focalLoss',
        help = 'define_loss') # loss function choice
    parser.add_argument('--init_lr', type=float, default = 1e-3,
        help = 'init_learningRate') # initial learning rate
    parser.add_argument('--lr_decay', type=float, default = 0.96,
        help = 'learningRate_decay') # learning-rate decay factor
    parser.add_argument('--weight_decay', type=float, default = 1e-6,
        help = 'weight_decay') # optimizer L2 regularization weight
    parser.add_argument('--batch_size', type=int, default = 48,
        help = 'batch_size') # images per training batch
    parser.add_argument('--dropout', type=float, default = 0.5,
        help = 'dropout') # dropout
    parser.add_argument('--epochs', type=int, default = 200,
        help = 'epochs') # number of training epochs
    parser.add_argument('--num_workers', type=int, default = 6,
        help = 'num_workers') # dataloader worker threads
    parser.add_argument('--img_size', type=_parse_size , default = (224,224),
        help = 'img_size') # model input size, fix: was type=tuple which split the string into chars
    parser.add_argument('--flag_agu', type=_str2bool , default = True,
        help = 'data_augmentation') # enable data augmentation
    parser.add_argument('--fix_res', type=_str2bool , default = False,
        help = 'fix_resolution') # keep the aspect ratio of input images (letterbox)
    parser.add_argument('--clear_model_exp', type=_str2bool, default = False,
        help = 'clear_model_exp') # wipe the model output folder first
    parser.add_argument('--log_flag', type=_str2bool, default = True,
        help = 'log flag') # save a training log

    #--------------------------------------------------------------------------
    args = parser.parse_args()
    #--------------------------------------------------------------------------
    # create a timestamped experiment subfolder under model_exp
    mkdir_(args.model_exp, flag_rm=args.clear_model_exp)
    loc_time = time.localtime()
    args.model_exp = args.model_exp + '/' +args.model+ '_'+time.strftime("%Y-%m-%d_%H-%M-%S", loc_time)+'/'
    mkdir_(args.model_exp, flag_rm=args.clear_model_exp)

    f_log = None
    if args.log_flag:
        # tee stdout into a log file inside the experiment folder
        f_log = Logger(filename = args.model_exp+'/train_{}.log'.format(time.strftime("%Y-%m-%d_%H-%M-%S",loc_time)))
        print("开始记录日志～")
        sys.stdout = f_log

    print('---------------------------------- log : {}'.format(time.strftime("%Y-%m-%d %H:%M:%S", loc_time)))
    print('\n/******************* {} ******************/\n'.format(parser.description))

    unparsed = vars(args)  # Namespace -> dict for logging and JSON dump
    for key in unparsed.keys():
        print('{} : {}'.format(key,unparsed[key]))

    unparsed['time'] = time.strftime("%Y-%m-%d %H:%M:%S", loc_time)

    # persist the run configuration next to the checkpoints
    # fix: use a context manager instead of a bare open()/close()
    with open(args.model_exp+'train_ops.json',"w",encoding='utf-8') as fs:
        json.dump(unparsed,fs,ensure_ascii=False,indent = 1)

    trainer(ops = args,f_log = f_log)  # run training

    if args.log_flag:
        sys.stdout = f_log
    print('well done : {}'.format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))
